#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/Atomic.cuh>
#include <c10/cuda/CUDAStream.h>
#include <cooperative_groups.h>

#include "Cameras.cuh"
#include "Common.h"
#include "Rasterization.h"
#include "Utils.cuh"

namespace gsplat {

    namespace cg = cooperative_groups;

    template <uint32_t CDIM, typename scalar_t>
    __global__ void rasterize_to_pixels_from_world_3dgs_bwd_kernel(
        const uint32_t C,
        const uint32_t N,
        const uint32_t n_isects,
        const bool packed,
        // fwd inputs
        const vec3* __restrict__ means,           // [N, 3]
        const vec4* __restrict__ quats,           // [N, 4]
        const vec3* __restrict__ scales,          // [N, 3]
        const scalar_t* __restrict__ colors,      // [C, N, CDIM] or [nnz, CDIM]
        const scalar_t* __restrict__ opacities,   // [C, N] or [nnz]
        const scalar_t* __restrict__ backgrounds, // [C, CDIM] or [nnz, CDIM]
        const bool* __restrict__ masks,           // [C, tile_height, tile_width]
        const uint32_t image_width,
        const uint32_t image_height,
        const uint32_t tile_size,
        const uint32_t tile_width,
        const uint32_t tile_height,
        // camera model
        const scalar_t* __restrict__ viewmats0, // [C, 4, 4]
        const scalar_t* __restrict__ viewmats1, // [C, 4, 4] optional for rolling shutter
        const scalar_t* __restrict__ Ks,        // [C, 3, 3]
        const CameraModelType camera_model_type,
        // uncented transform
        const UnscentedTransformParameters ut_params,
        const ShutterType rs_type,
        const scalar_t* __restrict__ radial_coeffs,     // [C, 6] or [C, 4] optional
        const scalar_t* __restrict__ tangential_coeffs, // [C, 2] optional
        const scalar_t* __restrict__ thin_prism_coeffs, // [C, 2] optional
        // intersections
        const int32_t* __restrict__ tile_offsets, // [C, tile_height, tile_width]
        const int32_t* __restrict__ flatten_ids,  // [n_isects]
        // fwd outputs
        const scalar_t* __restrict__ render_alphas, // [C, image_height, image_width, 1]
        const int32_t* __restrict__ last_ids,       // [C, image_height, image_width]
        // grad outputs
        const scalar_t* __restrict__ v_render_colors, // [C, image_height,
                                                      // image_width, CDIM]
        const scalar_t* __restrict__ v_render_alphas, // [C, image_height, image_width, 1]
        // grad inputs
        vec3* __restrict__ v_means,        // [N, 3]
        vec4* __restrict__ v_quats,        // [N, 4]
        vec3* __restrict__ v_scales,       // [N, 3]
        scalar_t* __restrict__ v_colors,   // [C, N, CDIM] or [nnz, CDIM]
        scalar_t* __restrict__ v_opacities // [C, N] or [nnz]
    ) {
        auto block = cg::this_thread_block();
        uint32_t cid = block.group_index().x;
        uint32_t tile_id =
            block.group_index().y * tile_width + block.group_index().z;
        uint32_t i = block.group_index().y * tile_size + block.thread_index().y;
        uint32_t j = block.group_index().z * tile_size + block.thread_index().x;

        tile_offsets += cid * tile_height * tile_width;
        render_alphas += cid * image_height * image_width;
        last_ids += cid * image_height * image_width;
        v_render_colors += cid * image_height * image_width * CDIM;
        v_render_alphas += cid * image_height * image_width;
        if (backgrounds != nullptr) {
            backgrounds += cid * CDIM;
        }
        if (masks != nullptr) {
            masks += cid * tile_height * tile_width;
        }

        // when the mask is provided, do nothing and return if
        // this tile is labeled as False
        if (masks != nullptr && !masks[tile_id]) {
            return;
        }

        const float px = (float)j + 0.5f;
        const float py = (float)i + 0.5f;
        // clamp this value to the last pixel
        const int32_t pix_id =
            min(i * image_width + j, image_width * image_height - 1);

        // Create rolling shutter parameter
        auto rs_params = RollingShutterParameters(
            viewmats0 + cid * 16,
            viewmats1 == nullptr ? nullptr : viewmats1 + cid * 16);
        // shift pointers to the current camera. note that glm is colume-major.
        const vec2 focal_length = {Ks[cid * 9 + 0], Ks[cid * 9 + 4]};
        const vec2 principal_point = {Ks[cid * 9 + 2], Ks[cid * 9 + 5]};

        // Create ray from pixel
        WorldRay ray;
        if (camera_model_type == CameraModelType::PINHOLE) {
            if (radial_coeffs == nullptr && tangential_coeffs == nullptr && thin_prism_coeffs == nullptr) {
                PerfectPinholeCameraModel::Parameters cm_params = {};
                cm_params.resolution = {image_width, image_height};
                cm_params.shutter_type = rs_type;
                cm_params.principal_point = {principal_point.x, principal_point.y};
                cm_params.focal_length = {focal_length.x, focal_length.y};
                PerfectPinholeCameraModel camera_model(cm_params);
                ray = camera_model.image_point_to_world_ray_shutter_pose(vec2(px, py), rs_params);
            } else {
                OpenCVPinholeCameraModel<>::Parameters cm_params = {};
                cm_params.resolution = {image_width, image_height};
                cm_params.shutter_type = rs_type;
                cm_params.principal_point = {principal_point.x, principal_point.y};
                cm_params.focal_length = {focal_length.x, focal_length.y};
                if (radial_coeffs != nullptr) {
                    cm_params.radial_coeffs = make_array<float, 6>(radial_coeffs + cid * 6);
                }
                if (tangential_coeffs != nullptr) {
                    cm_params.tangential_coeffs = make_array<float, 2>(tangential_coeffs + cid * 2);
                }
                if (thin_prism_coeffs != nullptr) {
                    cm_params.thin_prism_coeffs = make_array<float, 4>(thin_prism_coeffs + cid * 4);
                }
                OpenCVPinholeCameraModel camera_model(cm_params);
                ray = camera_model.image_point_to_world_ray_shutter_pose(vec2(px, py), rs_params);
            }
        } else if (camera_model_type == CameraModelType::FISHEYE) {
            OpenCVFisheyeCameraModel<>::Parameters cm_params = {};
            cm_params.resolution = {image_width, image_height};
            cm_params.shutter_type = rs_type;
            cm_params.principal_point = {principal_point.x, principal_point.y};
            cm_params.focal_length = {focal_length.x, focal_length.y};
            if (radial_coeffs != nullptr) {
                cm_params.radial_coeffs = make_array<float, 4>(radial_coeffs + cid * 4);
            }
            OpenCVFisheyeCameraModel camera_model(cm_params);
            ray = camera_model.image_point_to_world_ray_shutter_pose(vec2(px, py), rs_params);
        } else if (camera_model_type == CameraModelType::EQUIRECTANGULAR) {
            EquirectangularCameraModel::Parameters cm_params = {};
            cm_params.resolution = {image_width, image_height};
            cm_params.shutter_type = rs_type;
            EquirectangularCameraModel camera_model(cm_params);
            ray = camera_model.image_point_to_world_ray_shutter_pose(vec2(px, py), rs_params);
        } else {
            // should never reach here
            assert(false);
            return;
        }
        const vec3 ray_d = ray.ray_dir;
        const vec3 ray_o = ray.ray_org;

        // keep not rasterizing threads around for reading data
        bool done = (i < image_height && j < image_width) && ray.valid_flag;

        // have all threads in tile process the same gaussians in batches
        // first collect gaussians between range.x and range.y in batches
        // which gaussians to look through in this tile
        int32_t range_start = tile_offsets[tile_id];
        int32_t range_end =
            (cid == C - 1) && (tile_id == tile_width * tile_height - 1)
                ? n_isects
                : tile_offsets[tile_id + 1];
        const uint32_t block_size = block.size();
        const uint32_t num_batches =
            (range_end - range_start + block_size - 1) / block_size;

        extern __shared__ int s[];
        int32_t* id_batch = (int32_t*)s; // [block_size]
        vec4* xyz_opacity_batch =
            reinterpret_cast<vec4*>(&id_batch[block_size]); // [block_size]
        vec3* scale_batch =
            reinterpret_cast<vec3*>(&xyz_opacity_batch[block_size]); // [block_size]
        vec4* quat_batch =
            reinterpret_cast<vec4*>(&scale_batch[block_size]); // [block_size]
        float* rgbs_batch =
            (float*)&quat_batch[block_size]; // [block_size * CDIM]

        // this is the T AFTER the last gaussian in this pixel
        float T_final = 1.0f - render_alphas[pix_id];
        float T = T_final;
        // the contribution from gaussians behind the current one
        float buffer[CDIM] = {0.f};
        // index of last gaussian to contribute to this pixel
        const int32_t bin_final = done ? last_ids[pix_id] : 0;

        // df/d_out for this pixel
        float v_render_c[CDIM];
#pragma unroll
        for (uint32_t k = 0; k < CDIM; ++k) {
            v_render_c[k] = v_render_colors[pix_id * CDIM + k];
        }
        const float v_render_a = v_render_alphas[pix_id];

        // collect and process batches of gaussians
        // each thread loads one gaussian at a time before rasterizing
        const uint32_t tr = block.thread_rank();
        cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
        const int32_t warp_bin_final =
            cg::reduce(warp, bin_final, cg::greater<int>());
        for (uint32_t b = 0; b < num_batches; ++b) {
            // resync all threads before writing next batch of shared mem
            block.sync();

            // each thread fetch 1 gaussian from back to front
            // 0 index will be furthest back in batch
            // index of gaussian to load
            // batch end is the index of the last gaussian in the batch
            // These values can be negative so must be int32 instead of uint32
            const int32_t batch_end = range_end - 1 - block_size * b;
            const int32_t batch_size = min(block_size, batch_end + 1 - range_start);
            const int32_t idx = batch_end - tr;
            if (idx >= range_start) {
                // TODO: only support 1 camera for now so it is ok to abuse the index.
                int32_t g = flatten_ids[idx]; // flatten index in [C * N] or [nnz]
                id_batch[tr] = g;
                const vec3 xyz = means[g];
                const float opac = opacities[g];
                xyz_opacity_batch[tr] = {xyz.x, xyz.y, xyz.z, opac};
                scale_batch[tr] = scales[g];
                quat_batch[tr] = quats[g];
#pragma unroll
                for (uint32_t k = 0; k < CDIM; ++k) {
                    rgbs_batch[tr * CDIM + k] = colors[g * CDIM + k];
                }
            }
            // wait for other threads to collect the gaussians in batch
            block.sync();
            // process gaussians in the current batch for this pixel
            // 0 index is the furthest back gaussian in the batch
            for (uint32_t t = max(0, batch_end - warp_bin_final); t < batch_size;
                 ++t) {
                bool valid = done;
                if (batch_end - t > bin_final) {
                    valid = 0;
                }
                float alpha;
                float opac;
                float vis;

                mat3 R, S;
                vec3 xyz;
                vec3 scale;
                vec4 quat;
                mat3 Mt;
                vec3 o_minus_mu, gro, grd, grd_n, gcrod;
                float grayDist, power;
                if (valid) {
                    const vec4 xyz_opac = xyz_opacity_batch[t];
                    opac = xyz_opac[3];
                    xyz = {xyz_opac[0], xyz_opac[1], xyz_opac[2]};
                    scale = scale_batch[t];
                    quat = quat_batch[t];

                    R = quat_to_rotmat(quat);
                    S = mat3(
                        1.0f / scale[0],
                        0.f,
                        0.f,
                        0.f,
                        1.0f / scale[1],
                        0.f,
                        0.f,
                        0.f,
                        1.0f / scale[2]);
                    Mt = glm::transpose(R * S);
                    o_minus_mu = ray_o - xyz;
                    gro = Mt * o_minus_mu;
                    grd = Mt * ray_d;
                    grd_n = safe_normalize(grd);
                    gcrod = glm::cross(grd_n, gro);
                    grayDist = glm::dot(gcrod, gcrod);
                    power = -0.5f * grayDist;

                    vis = __expf(power);
                    alpha = min(0.999f, opac * vis);
                    if (power > 0.f || alpha < 1.f / 255.f) {
                        valid = false;
                    }
                }

                // if all threads are inactive in this warp, skip this loop
                if (!warp.any(valid)) {
                    continue;
                }
                float v_rgb_local[CDIM] = {0.f};
                vec3 v_mean_local = {0.f, 0.f, 0.f};
                vec3 v_scale_local = {0.f, 0.f, 0.f};
                vec4 v_quat_local = {0.f, 0.f, 0.f, 0.f};
                float v_opacity_local = 0.f;
                // initialize everything to 0, only set if the lane is valid
                if (valid) {
                    // compute the current T for this gaussian
                    float ra = 1.0f / (1.0f - alpha);
                    T *= ra;
                    // update v_rgb for this gaussian
                    const float fac = alpha * T;
#pragma unroll
                    for (uint32_t k = 0; k < CDIM; ++k) {
                        v_rgb_local[k] = fac * v_render_c[k];
                    }
                    // contribution from this pixel
                    float v_alpha = 0.f;
#pragma unroll
                    for (uint32_t k = 0; k < CDIM; ++k) {
                        v_alpha += (rgbs_batch[t * CDIM + k] * T - buffer[k] * ra) *
                                   v_render_c[k];
                    }

                    v_alpha += T_final * ra * v_render_a;
                    // contribution from background pixel
                    if (backgrounds != nullptr) {
                        float accum = 0.f;
#pragma unroll
                        for (uint32_t k = 0; k < CDIM; ++k) {
                            accum += backgrounds[k] * v_render_c[k];
                        }
                        v_alpha += -T_final * ra * accum;
                    }

                    if (opac * vis <= 0.999f) {
                        const float v_vis = opac * v_alpha;
                        float v_gradDist = -0.5f * vis * v_vis;
                        vec3 v_gcrod = 2.0f * v_gradDist * gcrod;
                        vec3 v_grd_n = -glm::cross(v_gcrod, gro);
                        vec3 v_gro = glm::cross(v_gcrod, grd_n);
                        vec3 v_grd = safe_normalize_bw(grd, v_grd_n);
                        mat3 v_Mt = glm::outerProduct(v_grd, ray_d) +
                                    glm::outerProduct(v_gro, o_minus_mu);
                        vec3 v_o_minus_mu = glm::transpose(Mt) * v_gro;

                        v_mean_local += -v_o_minus_mu;
                        quat_scale_to_preci_half_vjp(
                            quat, scale, R, glm::transpose(v_Mt), v_quat_local, v_scale_local);
                        v_opacity_local = vis * v_alpha;
                    }

#pragma unroll
                    for (uint32_t k = 0; k < CDIM; ++k) {
                        buffer[k] += rgbs_batch[t * CDIM + k] * fac;
                    }
                }
                warpSum<CDIM>(v_rgb_local, warp);
                warpSum(v_mean_local, warp);
                warpSum(v_scale_local, warp);
                warpSum(v_quat_local, warp);
                warpSum(v_opacity_local, warp);
                if (warp.thread_rank() == 0) {
                    int32_t g = id_batch[t]; // flatten index in [C * N] or [nnz]
                    float* v_rgb_ptr = (float*)(v_colors) + CDIM * g;
#pragma unroll
                    for (uint32_t k = 0; k < CDIM; ++k) {
                        gpuAtomicAdd(v_rgb_ptr + k, v_rgb_local[k]);
                    }

                    float* v_mean_ptr = (float*)(v_means) + 3 * g;
                    gpuAtomicAdd(v_mean_ptr, v_mean_local.x);
                    gpuAtomicAdd(v_mean_ptr + 1, v_mean_local.y);
                    gpuAtomicAdd(v_mean_ptr + 2, v_mean_local.z);

                    float* v_scale_ptr = (float*)(v_scales) + 3 * g;
                    gpuAtomicAdd(v_scale_ptr, v_scale_local.x);
                    gpuAtomicAdd(v_scale_ptr + 1, v_scale_local.y);
                    gpuAtomicAdd(v_scale_ptr + 2, v_scale_local.z);

                    float* v_quat_ptr = (float*)(v_quats) + 4 * g;
                    gpuAtomicAdd(v_quat_ptr, v_quat_local.x);
                    gpuAtomicAdd(v_quat_ptr + 1, v_quat_local.y);
                    gpuAtomicAdd(v_quat_ptr + 2, v_quat_local.z);
                    gpuAtomicAdd(v_quat_ptr + 3, v_quat_local.w);

                    gpuAtomicAdd(v_opacities + g, v_opacity_local);
                }
            }
        }
    }

    template <uint32_t CDIM>
    void launch_rasterize_to_pixels_from_world_3dgs_bwd_kernel(
        // Gaussian parameters
        const at::Tensor means,                     // [N, 3]
        const at::Tensor quats,                     // [N, 4]
        const at::Tensor scales,                    // [N, 3]
        const at::Tensor colors,                    // [C, N, 3] or [nnz, 3]
        const at::Tensor opacities,                 // [C, N] or [nnz]
        const at::optional<at::Tensor> backgrounds, // [C, 3]
        const at::optional<at::Tensor> masks,       // [C, tile_height, tile_width]
        // image size
        const uint32_t image_width,
        const uint32_t image_height,
        const uint32_t tile_size,
        // camera
        const at::Tensor viewmats0,               // [C, 4, 4]
        const at::optional<at::Tensor> viewmats1, // [C, 4, 4] optional for rolling shutter
        const at::Tensor Ks,                      // [C, 3, 3]
        const CameraModelType camera_model,
        // uncented transform
        const UnscentedTransformParameters ut_params,
        ShutterType rs_type,
        const at::optional<at::Tensor> radial_coeffs,     // [C, 6] or [C, 4] optional
        const at::optional<at::Tensor> tangential_coeffs, // [C, 2] optional
        const at::optional<at::Tensor> thin_prism_coeffs, // [C, 2] optional
        // intersections
        const at::Tensor tile_offsets, // [C, tile_height, tile_width]
        const at::Tensor flatten_ids,  // [n_isects]
        // forward outputs
        const at::Tensor render_alphas, // [C, image_height, image_width, 1]
        const at::Tensor last_ids,      // [C, image_height, image_width]
        // gradients of outputs
        const at::Tensor v_render_colors, // [C, image_height, image_width, 3]
        const at::Tensor v_render_alphas, // [C, image_height, image_width, 1]
        // outputs
        at::Tensor v_means,    // [N, 3]
        at::Tensor v_quats,    // [N, 4]
        at::Tensor v_scales,   // [N, 3]
        at::Tensor v_colors,   // [C, N, 3] or [nnz, 3]
        at::Tensor v_opacities // [C, N] or [nnz]
    ) {
        bool packed = opacities.dim() == 1;
        assert(packed == false); // only support non-packed for now

        uint32_t C = tile_offsets.size(0);       // number of cameras
        uint32_t N = packed ? 0 : means.size(0); // number of gaussians
        uint32_t tile_height = tile_offsets.size(1);
        uint32_t tile_width = tile_offsets.size(2);
        uint32_t n_isects = flatten_ids.size(0);

        // Each block covers a tile on the image. In total there are
        // C * tile_height * tile_width blocks.
        dim3 threads = {tile_size, tile_size, 1};
        dim3 grid = {C, tile_height, tile_width};

        int64_t shmem_size =
            tile_size * tile_size *
            (sizeof(int32_t) + sizeof(vec4) + sizeof(vec3) + sizeof(vec4) + sizeof(float) * CDIM);

        if (n_isects == 0) {
            // skip the kernel launch if there are no elements
            return;
        }

        // TODO: an optimization can be done by passing the actual number of
        // channels into the kernel functions and avoid necessary global memory
        // writes. This requires moving the channel padding from python to C side.
        if (cudaFuncSetAttribute(
                rasterize_to_pixels_from_world_3dgs_bwd_kernel<CDIM, float>,
                cudaFuncAttributeMaxDynamicSharedMemorySize,
                shmem_size) != cudaSuccess) {
            AT_ERROR(
                "Failed to set maximum shared memory size (requested ",
                shmem_size,
                " bytes), try lowering tile_size.");
        }

        rasterize_to_pixels_from_world_3dgs_bwd_kernel<CDIM, float>
            <<<grid, threads, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
                C,
                N,
                n_isects,
                packed,
                reinterpret_cast<vec3*>(means.data_ptr<float>()),
                reinterpret_cast<vec4*>(quats.data_ptr<float>()),
                reinterpret_cast<vec3*>(scales.data_ptr<float>()),
                colors.data_ptr<float>(),
                opacities.data_ptr<float>(),
                backgrounds.has_value() ? backgrounds.value().data_ptr<float>()
                                        : nullptr,
                masks.has_value() ? masks.value().data_ptr<bool>() : nullptr,
                image_width,
                image_height,
                tile_size,
                tile_width,
                tile_height,
                // camera model
                viewmats0.data_ptr<float>(),
                viewmats1.has_value() ? viewmats1.value().data_ptr<float>()
                                      : nullptr,
                Ks.data_ptr<float>(),
                camera_model,
                // uncented transform
                ut_params,
                rs_type,
                radial_coeffs.has_value() ? radial_coeffs.value().data_ptr<float>()
                                          : nullptr,
                tangential_coeffs.has_value()
                    ? tangential_coeffs.value().data_ptr<float>()
                    : nullptr,
                thin_prism_coeffs.has_value()
                    ? thin_prism_coeffs.value().data_ptr<float>()
                    : nullptr,
                // intersections
                tile_offsets.data_ptr<int32_t>(),
                flatten_ids.data_ptr<int32_t>(),
                render_alphas.data_ptr<float>(),
                last_ids.data_ptr<int32_t>(),
                v_render_colors.data_ptr<float>(),
                v_render_alphas.data_ptr<float>(),
                // outputs
                reinterpret_cast<vec3*>(v_means.data_ptr<float>()),
                reinterpret_cast<vec4*>(v_quats.data_ptr<float>()),
                reinterpret_cast<vec3*>(v_scales.data_ptr<float>()),
                v_colors.data_ptr<float>(),
                v_opacities.data_ptr<float>());
    }

// Explicit Instantiation: this should match how it is being called in .cpp
// file.
// TODO: this is slow to compile, can we do something about it?
#define __INS__(CDIM)                                                          \
    template void launch_rasterize_to_pixels_from_world_3dgs_bwd_kernel<CDIM>( \
        const at::Tensor means,                                                \
        const at::Tensor quats,                                                \
        const at::Tensor scales,                                               \
        const at::Tensor colors,                                               \
        const at::Tensor opacities,                                            \
        const at::optional<at::Tensor> backgrounds,                            \
        const at::optional<at::Tensor> masks,                                  \
        const uint32_t image_width,                                            \
        const uint32_t image_height,                                           \
        const uint32_t tile_size,                                              \
        const at::Tensor viewmats0,                                            \
        const at::optional<at::Tensor> viewmats1,                              \
        const at::Tensor Ks,                                                   \
        const CameraModelType camera_model,                                    \
        const UnscentedTransformParameters ut_params,                          \
        const ShutterType rs_type,                                             \
        const at::optional<at::Tensor> radial_coeffs,                          \
        const at::optional<at::Tensor> tangential_coeffs,                      \
        const at::optional<at::Tensor> thin_prism_coeffs,                      \
        const at::Tensor tile_offsets,                                         \
        const at::Tensor flatten_ids,                                          \
        const at::Tensor render_alphas,                                        \
        const at::Tensor last_ids,                                             \
        const at::Tensor v_render_colors,                                      \
        const at::Tensor v_render_alphas,                                      \
        at::Tensor v_means,                                                    \
        at::Tensor v_quats,                                                    \
        at::Tensor v_scales,                                                   \
        at::Tensor v_colors,                                                   \
        at::Tensor v_opacities);

    __INS__(1)
    __INS__(2)
    __INS__(3)
    __INS__(4)
    __INS__(5)
    __INS__(8)
    __INS__(9)
    __INS__(16)
    __INS__(17)
    __INS__(32)
    __INS__(33)
    __INS__(64)
    __INS__(65)
    __INS__(128)
    __INS__(129)
    __INS__(256)
    __INS__(257)
    __INS__(512)
    __INS__(513)

#undef __INS__

} // namespace gsplat
